%% Aim
% Predict choice and RT from the DDM (HDDM) simulated responses and compare them with observed behaviour

%% Setup
% Reset the command window and workspace, then start the run timer
% (read back by the final toc).
clc
clear
tic
% Every input/output path below is resolved relative to the current folder,
% and all of its subfolders are put on the MATLAB path.
root_path = pwd();
addpath(genpath(root_path));

% Design sizes used as loop bounds below.
nrole    = 2;    % roles, indexed like role_label
nprice   = 10;   % price levels 1..nprice
nlottery = 10;   % lottery levels; lottery values are 2*(1..nlottery)

role_label = {'Buyer', 'Seller'};
% RGB triplets (0-255 scaled to 0-1), one cell per role.
role_color_dark   = {[235 112  99]/255, [106 142 208]/255};
role_color_median = {[238 137 126]/255, [144 171 220]/255};
role_color_light  = {[243 171 163]/255, [165 187 227]/255};

bias_label = {'Val', 'Res', 'Val vs. Res'};
% NOTE(review): two colours are defined for three bias labels - confirm intended.
bias_color_median = {[195 205 147]/255, [192 173 203]/255};

%% Load data
% Input tables, located relative to root_path. Paths are built with fullfile
% so the script is not tied to Windows path separators.
%   behavior.csv                 - behavioural table (presumably one row per trial;
%                                  uses Subject, Role, Price, Lottery, Choice, RT)
%   subject.csv                  - one row per subject; nsub = its height
%   DDM4a_hddm_rt_simulated.csv  - HDDM simulated RTs (column RTSim), copied
%                                  row-wise onto beh_data below
beh_data = readtable(fullfile(root_path, 'data', 'behavior', 'behavior.csv'));
sub_data = readtable(fullfile(root_path, 'data', 'subject', 'subject.csv'));
nsub = height(sub_data);

hddm_data = readtable(fullfile(root_path, 'model', 'hddm', 'DDM4a', 'DDM4a_hddm_rt_simulated.csv'));

%% Incorporate data
% Attach the HDDM-simulated RTs to the behavioural table and derive the
% predicted RT, predicted ending and predicted choice from them.
% NOTE(review): hddm_data rows are copied positionally (no key join), so they
% are assumed to be in the same trial order as beh_data - confirm.
assert(height(hddm_data) == height(beh_data), ...
    'Simulated data has %d rows but behavior data has %d rows.', ...
    height(hddm_data), height(beh_data));
beh_data.RTSim = hddm_data.RTSim;

% The magnitude of the simulated RT is the predicted RT; a positive sign is
% coded as EndingPred = 1, a non-positive sign as EndingPred = 0.
beh_data.RTPred = abs(beh_data.RTSim);
ending_pred = double(beh_data.RTSim > 0);
% NaN > 0 is false, which would silently turn a missing simulation into a
% definite prediction; keep such trials missing instead.
ending_pred(isnan(beh_data.RTSim)) = NaN;
beh_data.EndingPred = ending_pred;

% Role 1 (Buyer):  ChoicePred = EndingPred
% Role 2 (Seller): ChoicePred = 1 - EndingPred
% (NaN EndingPred propagates to NaN ChoicePred.)
beh_data.ChoicePred = (2 - beh_data.Role) .* beh_data.EndingPred + (beh_data.Role - 1) .* (1 - beh_data.EndingPred);

%% Predicted Probability of trading: Matrix figure (group-level)

    % Group-level mean predicted choice for every Role x Price x Lottery cell.
    % Matrix rows index lottery (value 2*ilottery), columns index price (iprice).
    for irole=1:nrole
        prob_accept.matrix.group.role(irole).value = nan(nlottery, nprice);
        for iprice=1:nprice
            for ilottery=1:nlottery
                bdata=beh_data(  beh_data.Role==irole & ...
                                 beh_data.Price==iprice & ...
                                 beh_data.Lottery==ilottery*2,:);
                % NaN predictions are ignored; a cell with no trials stays NaN.
                prob_accept.matrix.group.role(irole).value(ilottery,iprice)=mean(bdata.ChoicePred,'omitnan');
            end
        end
    end

    % Diverging colormap (dark blue -> white -> dark red), shared by all panels.
    cmap_up=[linspace(1,192/256,1000)',linspace(1,0,1000)',linspace(1,0,1000)'];
    cmap_down=[linspace(0/256,1,1000)',linspace(32/256,1,1000)',linspace(102/256,1,1000)'];
    cmap=[cmap_down;cmap_up];
    % Tick labels derived from the design sizes instead of hard-coded lists.
    price_labels=arrayfun(@num2str, 1:nprice, 'UniformOutput', false);
    lottery_labels=arrayfun(@num2str, (1:nlottery)*2, 'UniformOutput', false);

    fig=figure('Renderer', 'painters', 'Position', [10 10 1000 350]);
    colormap(fig, cmap);
    for irole=1:nrole
        subplot(1,nrole,irole);
        imagesc(prob_accept.matrix.group.role(irole).value);
        set(gca,'TickDir','out');
        xlabel('Price');
        ylabel('Lottery');
        % x runs over matrix columns (price), y over matrix rows (lottery).
        xticks(1:nprice);
        yticks(1:nlottery);
        xticklabels(price_labels);
        yticklabels(lottery_labels);
        colorbar;
        caxis([0 1]);
        title(role_label{irole});
    end

    out_dir=fullfile(root_path, 'figure', 'behavior', 'choice', 'matrix');
    if ~exist(out_dir, 'dir')
        mkdir(out_dir);
    end
    % Print and close this figure by handle, not the hard-coded figure number 1.
    print(fig, fullfile(out_dir, 'prob_accept_matrix_group_pred_ddm.tif'), '-dtiff', '-r200');
    close(fig);

%% Raw Choice vs. Predicted Choice
% Per-subject accuracy of the DDM-predicted choice: the proportion of trials
% on which ChoicePred equals the observed Choice. Stored in
% sub_data.DDMChoicePred.
% NOTE(review): row isub of sub_data is matched to beh_data.Subject == isub,
% i.e. subject IDs are assumed to be 1..nsub in table order - confirm.
sub_data.DDMChoicePred = nan(nsub, 1);
for isub = 1:nsub
    bdata = beh_data(beh_data.Subject == isub, :);
    % NaN == x is always false, so trials with a missing observed or predicted
    % choice are excluded instead of being counted as mispredictions.
    valid = ~isnan(bdata.Choice) & ~isnan(bdata.ChoicePred);
    sub_data.DDMChoicePred(isub) = mean(bdata.Choice(valid) == bdata.ChoicePred(valid));
end

% Report the across-subject summary (subjects with no usable trials are NaN
% and are skipped by the summary statistics).
fprintf('DDM choice prediction accuracy: mean = %.3f, SD = %.3f, min = %.3f, max = %.3f\n', ...
    mean(sub_data.DDMChoicePred, 'omitnan'), std(sub_data.DDMChoicePred, 'omitnan'), ...
    min(sub_data.DDMChoicePred), max(sub_data.DDMChoicePred));

%% Predicted RT: Matrix figure (group-level)

    % Group-level mean predicted RT for every Role x Price x Lottery cell.
    % Matrix rows index lottery (value 2*ilottery), columns index price (iprice).
    for irole=1:nrole
        rt.matrix.group.role(irole).value = nan(nlottery, nprice);
        for iprice=1:nprice
            for ilottery=1:nlottery
                bdata=beh_data(  beh_data.Role==irole & ...
                                 beh_data.Price==iprice & ...
                                 beh_data.Lottery==ilottery*2,:);
                % NaN predictions are ignored; a cell with no trials stays NaN.
                rt.matrix.group.role(irole).value(ilottery,iprice)=mean(bdata.RTPred,'omitnan');
            end
        end
    end

    % Diverging colormap (dark blue -> white -> dark red), shared by all panels.
    cmap_up=[linspace(1,192/256,1000)',linspace(1,0,1000)',linspace(1,0,1000)'];
    cmap_down=[linspace(0/256,1,1000)',linspace(32/256,1,1000)',linspace(102/256,1,1000)'];
    cmap=[cmap_down;cmap_up];
    % Tick labels derived from the design sizes instead of hard-coded lists.
    price_labels=arrayfun(@num2str, 1:nprice, 'UniformOutput', false);
    lottery_labels=arrayfun(@num2str, (1:nlottery)*2, 'UniformOutput', false);

    fig=figure('Renderer', 'painters', 'Position', [10 10 1000 350]);
    colormap(fig, cmap);
    for irole=1:nrole
        subplot(1,nrole,irole);
        imagesc(rt.matrix.group.role(irole).value);
        set(gca,'TickDir','out');
        xlabel('Price');
        ylabel('Lottery');
        % x runs over matrix columns (price), y over matrix rows (lottery).
        xticks(1:nprice);
        yticks(1:nlottery);
        xticklabels(price_labels);
        yticklabels(lottery_labels);
        colorbar;
        % Role-specific colour limits for the predicted RT.
        if irole==1
            caxis([0.9 1.2]);
        else
            caxis([0.9 1.4]);
        end
        title(role_label{irole});
    end

    out_dir=fullfile(root_path, 'figure', 'behavior', 'rt', 'matrix');
    if ~exist(out_dir, 'dir')
        mkdir(out_dir);
    end
    % Print and close this figure by handle, not the hard-coded figure number 1.
    print(fig, fullfile(out_dir, 'rt_matrix_group_pred_ddm.tif'), '-dtiff', '-r200');
    close(fig);

%% Raw RT vs. Predicted RT
% Per-subject Pearson correlation between observed RT and DDM-predicted RT
% (stored in sub_data.DDMRTPred), then a one-sample t-test of those
% correlations against 0.
% NOTE(review): row isub of sub_data is matched to beh_data.Subject == isub,
% i.e. subject IDs are assumed to be 1..nsub in table order - confirm.
sub_data.DDMRTPred = nan(nsub, 1);
for isub = 1:nsub
    bdata = beh_data(beh_data.Subject == isub, :);
    % 'Rows','complete' drops trials with a missing observed or predicted RT
    % instead of returning NaN for the whole subject.
    sub_data.DDMRTPred(isub) = corr(bdata.RT, bdata.RTPred, 'Rows', 'complete');
end

fprintf('DDM RT prediction (r): mean = %.3f, SD = %.3f, min = %.3f, max = %.3f\n', ...
    mean(sub_data.DDMRTPred, 'omitnan'), std(sub_data.DDMRTPred, 'omitnan'), ...
    min(sub_data.DDMRTPred), max(sub_data.DDMRTPred));

% Two-tailed test by default; add 'tail','right' for a one-sided test.
% NOTE(review): consider Fisher z-transforming (atanh) the correlations first.
[h,p,ci,stats]=ttest(sub_data.DDMRTPred,0);
fprintf('One-sample t-test vs 0: t(%d) = %.3f, p = %.4g, 95%% CI [%.3f, %.3f]\n', ...
    stats.df, stats.tstat, p, ci(1), ci(2));

%% save
% Write the subject table, including the DDMChoicePred and DDMRTPred
% columns, to data/subject/subject_ddm_pred.csv. The path is built with
% fullfile so it is not Windows-specific.
writetable(sub_data, fullfile(root_path, 'data', 'subject', 'subject_ddm_pred.csv'));

%%
% Report elapsed time since tic in the setup section.
toc;
